热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

评测|CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测|CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

作者:Max Woolf

机器之心编译

参与:Jane W、吴攀

Keras 是由 François Chollet 维护的深度学习高级开源框架,它的底层基于构建生产级质量的深度学习模型所需的大量设置和矩阵代数。Keras API 的底层基于像 Theano 或谷歌的 TensorFlow 的较低级的深度学习框架。Keras 可以通过设置 flag 自由切换后端(backend)引擎 Theano/TensorFlow;而不需要更改前端代码。

虽然谷歌的 TensorFlow 已广受关注,但微软也一直在默默地发布自己的机器学习开源框架。例如 LightGBM 框架,可以作为著名的 xgboost 库的替代品。例如几周前发布的 CNTK v2.0(Microsoft Cognitive Toolkit),它与 TensorFlow 相比,显示出在准确性和速度方面的强劲性能。参阅机器之心报道《开源 | 微软发行 Cognitive Toolkit 2.0 完整版:从性能更新到应用案例》。

CNTK v2.0 还有一个关键特性:兼容 Keras。就在上周,对 CNTK 后端的支持被合并到官方的 Keras 资源库(repository)中。

Hacker News 论坛对于 CNTK v2.0 也有评论(https://news.ycombinator.com/item?id=14470967),微软员工声称,将 Keras 的后端由 TensorFlow 改为 CNTK 可以显著提升性能。那么让我们来检验这句话的真伪吧。

在云端进行深度学习

在云端设置基于 GPU 的深度学习实例令人惊讶地被忽视了。大多数人建议使用亚马逊 AWS 服务,它包含所有可用的 GPU 驱动,只需参照固定流程(https://blog.keras.io/running-jupyter-notebooks-on-gpu-on-aws-a-starter-guide.html)设置远程操作。然而,对于 NVIDIA Tesla K80 GPU,亚马逊 EC2 收费 $0.90/小时(不按时长比例收费);对于相同的 GPU,谷歌 Compute Engine(GCE)收费 $0.75/小时(按分钟比例收费),这对于需要训练许多小时的深度学习模型是非常显著的弱点。

要使用 GCE,你必须从一个空白的 Linux 实例中设置深度学习的驱动和框架。我使用 Keras 进行了第一次尝试(http://minimaxir.com/2017/04/char-embeddings/),但这并不有趣。不过,我最近受到 Durgesh Mankekar 文章(https://medium.com/google-cloud/containerized-jupyter-notebooks-on-gpu-on-google-cloud-8e86ef7f31e9)的启发,该文章采用了 Docker 容器这种更现代的方法来管理依赖关系,该文章还介绍了名为 Dockerfile 的安装脚本和容器与 Keras 必需的深度学习驱动/框架。Docker 容器可以使用 nvidia-docker 进行加载,这可以让 Docker 容器访问主机上的 GPU。在容器中运行深度学习脚本只需运行 Docker 命令行。当脚本运行完后,会自动退出容器。这种方法恰巧保证了每次执行是独立的;这为基准评估/重复执行提供了理想的环境。

我稍微调整了 Docker 容器(GitHub 网址 https://github.com/minimaxir/keras-cntk-docker),容器安装了 CNTK、与 CNTK 兼容的 Keras 版本,并设置 CNTK 为 Keras 的默认后端。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

基准方法

Keras 的官方案例(https://github.com/fchollet/keras/tree/master/examples)非常全面,涉及多种现实中的深度学习问题,并能完美地模拟 Keras 在不同模型的性能。我选取了强调不同神经网络架构的几个例子(https://github.com/minimaxir/keras-cntk-benchmark/tree/master/test_files),并添加了一个自定义 logger,它能够输出含有模型性能和训练时间进程的 CSV 文件。

如前所述,只需要设置一个 flag 就能方便地切换后端引擎。即使 Docker 容器中 Keras 的默认后端是 CNTK,一个简单的 -e KERAS_BACKEND ='tensorflow' 命令语句就可以切换到 TensorFlow。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

我写了一个 Python 基准脚本(https://github.com/minimaxir/keras-cntk-benchmark/blob/master/keras_cntk_benchmark.py)(在主机上运行)来管理并运行 Docker 容器中的所有例子,它同时支持 CNTK 和 TensorFlow 后端,并用 logger 收集生成的日志。

下面是不同数据集的结果。

IMDb 评论数据集

IMDb 评论数据集(http://ai.stanford.edu/~amaas/data/sentiment/)是用于情感分析的著名的自然语言处理(NLP)基准数据集。数据集中的 25000 条评论被标记为「积极」或「消极」。在深度学习成为主流之前,优秀的机器学习模型在测试集上达到大约 88% 的分类准确率。

第一个模型方法(imdb_bidirectional_lstm.py)使用了双向 LSTM(Bidirectional LSTM),它通过词序列对模型进行加权,同时采用向前(forward)传播和向后(backward)传播的方法。

首先,我们来看一下在训练模型时的不同时间点测试集的分类准确率:

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

通常,准确率随着训练的进行而增加;双向 LSTM 需要很长时间来训练才能得到改进的结果,但至少这两个框架都是同样有效的。

为了评估算法的速度,我们可以计算训练一个 epoch 所需的平均时间。每个 epoch 的时间大致相同;测量结果真实平均值用 95%的置信区间表示,这是通过非参数统计的 bootstrapping 方法得到的。双向 LSTM 的计算速度:

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

哇,CNTK 比 TensorFlow 快很多!虽然没有比 LSTM 的基准测试(https://arxiv.org/abs/1608.07249)快 5-10 倍,但是仅通过设置后端 flag 就几乎将运行时间减半就已经够令人震惊了。

接下来,我们用同样的数据集测试 fasttext 方法(imdb_fasttext.py)。fasttext 是一种较新的算法,可以计算词向量嵌入(word vector Embedding)的平均值(不论顺序),但是即使在使用 CPU 时也能得到令人难以置信的速度和效果,如同 Facebook 官方对 fasttext 的实现(https://github.com/facebookresearch/fastText)一样。(对于此基准,我倾向于使用二元语法模型/bigram)

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

由于模型简单,这两种框架的准确率几乎相同,但在使用词嵌入的情况下,TensorFlow 速度更快。(不管怎样,fasttext 明显比双向 LSTM 方法快得多!)此外,fasttext 打破了 88%的基准,这可能值得考虑在其它机器学习项目中推广。

MNIST 数据集

MNIST 数据集(http://yann.lecun.com/exdb/mnist/)是另一个著名的手写数字数据集,经常用于测试计算机视觉模型(60000 个训练图像,10000 个测试图像)。一般来说,良好的模型在测试集上可达到 99%以上的分类准确率。

多层感知器(multilayer perceptron/MLP)方法(mnist_mlp.py)仅使用一个大型全连接网络,就达到深度学习魔术(Deep Learning Magic™)的效果。有时候这样就够了。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

这两个框架都能极速地训练模型,每个 epoch 只需几秒钟;在准确性方面没有明确的赢家(尽管没有打破 99%),但是 CNTK 速度更快。

另一种方法(mnist_cnn.py)是卷积神经网络(CNN),它利用相邻像素之间的固有关系建模,是一种逻辑上更贴近图像数据的架构。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

在这种情况下,TensorFlow 在准确率和速度方面都表现更好(同时也打破 99%的准确率)。

CIFAR-10

现在来研究更复杂的实际模型,CIFAR-10 数据集(https://www.cs.toronto.edu/~kriz/cifar.html)是用于 10 个不同对象的图像分类的数据集。基准脚本的架构(cifar10_cnn.py)是很多层的 Deep CNN + MLP,其架构类似于著名的 VGG-16(https://gist.github.com/baraldilorenzo/07d7802847aaad0a35d3)模型,但更简单,由于大多数人没有用来训练的超级计算机集群。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

在这种情况下,两个后端的在准确率和速度上的性能均相等。也许 CNTK 更利于 MLP,而 TensorFlow 更利于 CNN,两者的优势互相抵消。

尼采文本生成

基于 char-rnn(https://github.com/karpathy/char-rnn)的文本生成(lstm_text_generation.py)很受欢迎。具体来说,它使用 LSTM 来「学习」文本并对新文本进行抽样。在使用随机的尼采文集(https://s3.amazonaws.com/text-datasets/nietzsche.txt)作为源数据集的 Keras 例子中,该模型尝试使用前 40 个字符预测下一个字符,并尽量减少训练的损失函数值。理想情况的是损失函数值低于 1.00,并且生成的文本语法一致。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

两者的损失函数值随时间都有相似的变化(不幸的是,1.40 的损失函数值下,仍有乱码文本生成),由于 LSTM 架构,CTNK 的速度更快。

对于下一个基准测试,我将不使用官方的 Keras 示例脚本,而是使用我自己的文本生成器架构(text_generator_keras.py),详见之前关于 Keras 的文章(http://minimaxir.com/2017/04/char-embeddings)。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

我的网络避免了过早收敛,对于 TensorFlow,只需损失很小的训练速度;不幸的是,CNTK 的速度比简单模型慢了许多,但在高级模型中仍然比 TensorFlow 快得多。

以下是用 TensorFlow 训练的我的架构模型生成的文本输出:

hinks the rich man must be wholly perverity and connection of the english sin of the philosophers of the basis of the same profound of his placed and evil and exception of fear to plants to me such as the case of the will seems to the will to be every such a remark as a primates of a strong of [...]

这是用 CNTK 训练的模型输出:

(_x2js1hevjg4z_?z_aæ?q_gpmj:sn![?(f3_ch=lhw4y n6)gkh kujau momu,?!ljë7g)k,!?[45 0as9[d.68éhhptvsx jd_næi,ä_z!cwkr"_f6ë-mu_(epp [...]

等等,什么?显然,我的模型架构导致 CNTK 在预测时遇到错误,而「CNTK+简单的 LSTM」架构并没有发生这种错误。通过质量评估,我发现批归一化(batch normalization)是错误的原因,并及时提出了这个问题(https://github.com/Microsoft/CNTK/issues/1994)。

结论

综上,评价 Keras 框架是否比 TensorFlow 更好,这个判断并没有设想中的那么界限分明。两个框架的准确性大致相同。CNTK 在 LSTM/MLP 上更快,TensorFlow 在 CNN/词嵌入(Embedding)上更快,但是当网络同时实现两者时,它们会打个平手。

撇开随机错误,有可能 CNTK 在 Keras 上的运行还没有完全优化(实际上,1bit-SGD 的设置不起作用(https://github.com/Microsoft/CNTK/issues/1975)),所以未来还是有改进的空间的。尽管如此,简单地设置 flag 的效果是非常显著的,在将它们部署到生产之前,值得在 CNTK 和 TensorFlow 后端上测试 Keras 模型,以比较两者哪个更好。  评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

原文链接:http://minimaxir.com/2017/06/keras-cntk/

版权声明

本文仅代表作者观点,不代表百度立场。

阅读量: 0

0

0


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 我们


推荐阅读
  • 1.脚本功能1)自动替换jar包中的配置文件。2)自动备份老版本的Jar包3)自动判断是初次启动还是更新服务2.脚本准备进入ho ... [详细]
  • 本文介绍了Hyperledger Fabric外部链码构建与运行的相关知识,包括在Hyperledger Fabric 2.0版本之前链码构建和运行的困难性,外部构建模式的实现原理以及外部构建和运行API的使用方法。通过本文的介绍,读者可以了解到如何利用外部构建和运行的方式来实现链码的构建和运行,并且不再受限于特定的语言和部署环境。 ... [详细]
  • 使用在线工具jsonschema2pojo根据json生成java对象
    本文介绍了使用在线工具jsonschema2pojo根据json生成java对象的方法。通过该工具,用户只需将json字符串复制到输入框中,即可自动将其转换成java对象。该工具还能解析列表式的json数据,并将嵌套在内层的对象也解析出来。本文以请求github的api为例,展示了使用该工具的步骤和效果。 ... [详细]
  • Voicewo在线语音识别转换jQuery插件的特点和示例
    本文介绍了一款名为Voicewo的在线语音识别转换jQuery插件,该插件具有快速、架构、风格、扩展和兼容等特点,适合在互联网应用中使用。同时还提供了一个快速示例供开发人员参考。 ... [详细]
  • flowable工作流 流程变量_信也科技工作流平台的技术实践
    1背景随着公司业务发展及内部业务流程诉求的增长,目前信息化系统不能够很好满足期望,主要体现如下:目前OA流程引擎无法满足企业特定业务流程需求,且移动端体 ... [详细]
  • 本文介绍了网页播放视频的三种实现方式,分别是使用html5的video标签、使用flash来播放以及使用object标签。其中,推荐使用html5的video标签来简单播放视频,但有些老的浏览器不支持html5。另外,还可以使用flash来播放视频,需要使用object标签。 ... [详细]
  • Servlet多用户登录时HttpSession会话信息覆盖问题的解决方案
    本文讨论了在Servlet多用户登录时可能出现的HttpSession会话信息覆盖问题,并提供了解决方案。通过分析JSESSIONID的作用机制和编码方式,我们可以得出每个HttpSession对象都是通过客户端发送的唯一JSESSIONID来识别的,因此无需担心会话信息被覆盖的问题。需要注意的是,本文讨论的是多个客户端级别上的多用户登录,而非同一个浏览器级别上的多用户登录。 ... [详细]
  • 云原生应用最佳开发实践之十二原则(12factor)
    目录简介一、基准代码二、依赖三、配置四、后端配置五、构建、发布、运行六、进程七、端口绑定八、并发九、易处理十、开发与线上环境等价十一、日志十二、进程管理当 ... [详细]
  • {moduleinfo:{card_count:[{count_phone:1,count:1}],search_count:[{count_phone:4 ... [详细]
  • 本文介绍了brain的意思、读音、翻译、用法、发音、词组、同反义词等内容,以及脑新东方在线英语词典的相关信息。还包括了brain的词汇搭配、形容词和名词的用法,以及与brain相关的短语和词组。此外,还介绍了与brain相关的医学术语和智囊团等相关内容。 ... [详细]
  • Vagrant虚拟化工具的安装和使用教程
    本文介绍了Vagrant虚拟化工具的安装和使用教程。首先介绍了安装virtualBox和Vagrant的步骤。然后详细说明了Vagrant的安装和使用方法,包括如何检查安装是否成功。最后介绍了下载虚拟机镜像的步骤,以及Vagrant镜像网站的相关信息。 ... [详细]
  • Sleuth+zipkin链路追踪SpringCloud微服务的解决方案
    在庞大的微服务群中,随着业务扩展,微服务个数增多,系统调用链路复杂化。Sleuth+zipkin是解决SpringCloud微服务定位和追踪的方案。通过TraceId将不同服务调用的日志串联起来,实现请求链路跟踪。通过Feign调用和Request传递TraceId,将整个调用链路的服务日志归组合并,提供定位和追踪的功能。 ... [详细]
  • 本文介绍了自学Vue的第01天的内容,包括学习目标、学习资料的收集和学习方法的选择。作者解释了为什么要学习Vue以及选择Vue的原因,包括完善的中文文档、较低的学习曲线、使用人数众多等。作者还列举了自己选择的学习资料,包括全新vue2.5核心技术全方位讲解+实战精讲教程、全新vue2.5项目实战全家桶单页面仿京东电商等。最后,作者提出了学习方法,包括简单的入门课程和实战课程。 ... [详细]
  • 本文介绍了Hive常用命令及其用途,包括列出数据表、显示表字段信息、进入数据库、执行select操作、导出数据到csv文件等。同时还涉及了在AndroidManifest.xml中获取meta-data的value值的方法。 ... [详细]
  • 服务网关与流量网关
    一、为什么需要服务网关1、什么是服务网关传统的单体架构中只需要开放一个服务给客户端调用,但是微服务架构中是将一个系统拆分成多个微服务,如果没有网关& ... [详细]
author-avatar
mobiledu2502856973
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有